
[perf] Add zmq.proxy to accelerate request processing for SimpleStorageUnit #37

Merged

0oshowero0 merged 9 commits into Ascend:main from 0oshowero0:multithread on Feb 28, 2026
Conversation

0oshowero0 (Collaborator) commented Feb 26, 2026

Background

Previously, SimpleStorageUnit relied on a single-threaded event loop for request processing. This design could lead to bottlenecks and increased latency when multiple requests arrived simultaneously, as operations like ZMQ message deserialization and memory I/O would block the main socket loop from receiving new requests.

Key Changes

  1. Refactored SimpleStorageUnit to utilize a native zmq.proxy. This acts as a highly efficient, C-level load balancer between a frontend ROUTER socket (handling external client connections) and an internal backend DEALER socket (inproc://).
  2. Introduced a worker thread pool where each worker binds its own independent DEALER socket to process PUT/GET/CLEAR requests concurrently. This preserves ZMQ's "share-nothing" concurrency philosophy.
  3. Added a threading.Lock() to StorageUnitData to prevent race conditions introduced by multi-threading.
  4. Added num_worker_threads as an explicit input parameter for SimpleStorageUnit (configurable via TQ system config items).
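The frontend/backend wiring described in items 1 and 2 can be sketched as follows. This is a minimal, self-contained illustration using pyzmq, not the actual SimpleStorageUnit code; the socket names, the inproc address, and the `ACK:` reply protocol are invented for the demo:

```python
import threading
import zmq  # pyzmq

INPROC_ADDR = "inproc://demo_workers"  # assumed name, not the real one

def worker_routine(context, stop_event):
    """Each worker owns an independent DEALER socket (ZMQ share-nothing)."""
    sock = context.socket(zmq.DEALER)
    sock.connect(INPROC_ADDR)
    poller = zmq.Poller()
    poller.register(sock, zmq.POLLIN)
    while not stop_event.is_set():
        if dict(poller.poll(timeout=100)).get(sock):
            frames = sock.recv_multipart()  # [client_identity, payload]
            # Real code would deserialize and dispatch PUT/GET/CLEAR here.
            frames[-1] = b"ACK:" + frames[-1]
            sock.send_multipart(frames)
    sock.close(linger=0)

def proxy_routine(frontend, backend):
    try:
        # C-level forwarding loop; blocks until the context is terminated.
        zmq.proxy(frontend, backend)
    except zmq.ZMQError:
        pass  # context terminated during shutdown
    frontend.close(linger=0)
    backend.close(linger=0)

def start_storage_unit(context, stop_event, num_worker_threads=1):
    frontend = context.socket(zmq.ROUTER)  # external clients connect here
    frontend.bind("tcp://127.0.0.1:*")     # wildcard port for the demo
    addr = frontend.getsockopt(zmq.LAST_ENDPOINT).decode()
    backend = context.socket(zmq.DEALER)   # internal fan-out to workers
    backend.bind(INPROC_ADDR)              # bind BEFORE workers connect
    for _ in range(num_worker_threads):
        threading.Thread(target=worker_routine, args=(context, stop_event),
                         daemon=True).start()
    threading.Thread(target=proxy_routine, args=(frontend, backend),
                     daemon=True).start()
    return addr

ctx = zmq.Context()
stop = threading.Event()
addr = start_storage_unit(ctx, stop, num_worker_threads=1)

client = ctx.socket(zmq.DEALER)  # clients speak DEALER to the frontend ROUTER
client.setsockopt(zmq.RCVTIMEO, 10000)
client.connect(addr)
client.send_multipart([b"hello"])
reply = client.recv_multipart()
print(reply)  # [b'ACK:hello']

stop.set()
client.close(linger=0)
ctx.term()
```

Note that the backend DEALER is bound before any worker thread connects, which sidesteps the inproc bind-before-connect ordering requirement.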

During performance testing, we were surprised to find that the refactored multi-threaded code achieves its best performance with num_worker_threads=1. The introduction of the native C-level zmq.proxy already offloads the high-frequency I/O from the main Python thread. We therefore retired the multi-threaded version and kept only the zmq.proxy optimization.

Architecture

Old Version

(architecture diagram: mermaid-diagram-2026-02-26-192209)

New Version

(architecture diagram: mermaid-diagram-2026-02-26-220631)

Performance Gain

We provide a simple benchmark script for this PR:

import argparse
import multiprocessing
import time
import ray
import torch
import zmq
import tensordict

# Ensure this runs in the repository root directory, otherwise sys.path.append might be needed
from transfer_queue.storage.simple_backend import SimpleStorageUnit
from transfer_queue.utils.zmq_utils import ZMQMessage, ZMQRequestType


class StorageClient:
    """Independent test client that interacts directly with the frontend ROUTER of SimpleStorageUnit"""

    def __init__(self, address):
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.DEALER)
        self.socket.setsockopt(zmq.RCVTIMEO, 20000)  # Timeout set to 20s to prevent timeouts under heavy concurrency
        self.socket.connect(address)

    def send_put(self, client_id, local_indexes, field_data):
        msg = ZMQMessage.create(
            request_type=ZMQRequestType.PUT_DATA,
            sender_id=f"bench_client_{client_id}",
            body={"local_indexes": local_indexes, "data": field_data},
        )
        self.socket.send_multipart(msg.serialize())
        return ZMQMessage.deserialize(self.socket.recv_multipart())

    def close(self):
        self.socket.close()
        self.context.term()


def client_worker(worker_id, address, num_requests, batch_size):
    """Worker process task: Continuously bombard the Storage Unit with PUT requests"""
    client = StorageClient(address)
    start_time = time.time()

    # Construct Dummy Tensor data to simulate actual memory and serialization overhead
    # As noted in the PR description, serialization and memory I/O are the bottlenecks blocking the main loop
    field_data = {
        "dummy_tensor": [torch.randn(256, 256) for _ in range(batch_size)]
    }

    for i in range(num_requests):
        local_indexes = list(range(i * batch_size, (i + 1) * batch_size))
        client.send_put(worker_id, local_indexes, field_data)

    elapsed = time.time() - start_time
    client.close()

    print(f"[Worker {worker_id}] Completed {num_requests} write requests, took {elapsed:.3f} seconds "
          f"(QPS: {num_requests / elapsed:.2f} req/s)")


def main(num_clients, storage_threads, requests_per_client):
    # Initialize Ray and global settings
    ray.init(ignore_reinit_error=True)
    tensordict.set_list_to_stack(True).set()

    try:
        print(f"🚀 Launching SimpleStorageUnit, internal worker threads (num_worker_threads): {storage_threads} ...")

        # Launch the backend Actor. PR 37 exposes the num_worker_threads parameter
        storage_actor = SimpleStorageUnit.options(
            max_concurrency=50, num_cpus=2
        ).remote(
            storage_unit_size=1000000,
            num_worker_threads=storage_threads,  # comment out this line to benchmark the old version
        )

        zmq_info = ray.get(storage_actor.get_zmq_server_info.remote())
        put_get_address = zmq_info.to_addr("put_get_socket")
        print(f"✅ Storage unit ready, ZMQ Address: {put_get_address}")

        # Wait for zmq.proxy and all worker threads to bind to the inproc port
        time.sleep(2)

        print(f"🔥 Spawning {num_clients} independent concurrent write processes...")
        processes = []
        batch_size = 256

        start_time = time.time()

        # 1. Create and start multiple processes
        for i in range(num_clients):
            p = multiprocessing.Process(
                target=client_worker,
                args=(i, put_get_address, requests_per_client, batch_size)
            )
            p.start()
            processes.append(p)

        # 2. Wait for all concurrent processes to complete
        for p in processes:
            p.join()

        total_time = time.time() - start_time
        total_requests = num_clients * requests_per_client

        print("\n" + "=" * 50)
        print(" 📊 Benchmark Results")
        print("=" * 50)
        print(f" SimpleStorageUnit internal threads : {storage_threads}")
        print(f" External concurrent clients        : {num_clients}")
        print(f" Total processed requests (Batches) : {total_requests} (Batch Size: {batch_size})")
        print(f" Total benchmark duration           : {total_time:.3f} seconds")
        print(f" 🚀 Overall Throughput              : {total_requests / total_time:.2f} req/s")
        print("=" * 50 + "\n")

    finally:
        # Resource cleanup
        if 'storage_actor' in locals():
            ray.kill(storage_actor)
        ray.shutdown()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PR #37 Performance Benchmark")
    parser.add_argument("--clients", type=int, default=8, help="Number of concurrent client processes")
    parser.add_argument("--threads", type=int, default=4, help="Number of processing threads in SimpleStorageUnit")
    parser.add_argument("--requests", type=int, default=300, help="Number of requests sent per client")

    args = parser.parse_args()
    main(args.clients, args.threads, args.requests)

Small Scale Test (batch_size=20, clients=4)

On a Mac mini (M2, 24 GB memory):

Old Version

python benchmark.py --clients 4
(benchmark screenshot)

New Version

python benchmark.py --clients 4 --threads 1
(benchmark screenshot)
python benchmark.py --clients 4 --threads 2
(benchmark screenshot)

Middle Scale Test (batch_size=256, clients=4)

On a Mac mini (M2, 24 GB memory):

Old Version

python benchmark.py --clients 4
(benchmark screenshot)

New Version

python benchmark.py --clients 4 --threads 1
(benchmark screenshot)
python benchmark.py --clients 4 --threads 2
(benchmark screenshot)

Large Scale Test (batch_size=256, clients=50)

On an Ubuntu server with an Intel(R) Xeon(R) Platinum 8358P CPU @ 2.60 GHz (128 cores):

Note:

  1. The benchmark script was also modified to measure GET performance.
  2. We export the following environment variables:
export OMP_NUM_THREADS=1
export MKL_NUM_THREADS=1
export OPENBLAS_NUM_THREADS=1
export VECLIB_MAXIMUM_THREADS=1
export NUMEXPR_NUM_THREADS=1
export TORCH_NUM_THREADS=1
export TQ_ZERO_COPY_SERIALIZATION=True

Old Version

python benchmark.py --clients 50
(benchmark screenshot)

New Version

python benchmark.py --clients 50 --threads 1
(benchmark screenshot)
python benchmark.py --clients 50 --threads 2
(benchmark screenshot)
python benchmark.py --clients 50 --threads 4
(benchmark screenshot)

Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
@ascend-robot

CLA Signature Pass

0oshowero0, thanks for your pull request. All authors of the commits have signed the CLA. 👍

Copilot AI (Contributor) left a comment

Pull request overview

This pull request refactors SimpleStorageUnit from a single-threaded architecture to a multi-threaded design using ZMQ proxy and worker thread pools. The change aims to improve performance by allowing concurrent request processing, eliminating bottlenecks from sequential message handling.

Changes:

  • Introduced a native zmq.proxy to load-balance between a frontend ROUTER socket and backend DEALER socket
  • Added a worker thread pool where each worker processes PUT/GET/CLEAR requests concurrently
  • Added thread synchronization primitives (Lock, Event) to StorageUnitData and SimpleStorageUnit for race condition prevention
  • Added num_worker_threads configuration parameter (default: 4) to control worker pool size

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 5 comments.

Reviewed files:

  • transfer_queue/storage/simple_backend.py — Core refactoring: replaced the single-threaded event loop with zmq.proxy and a worker pool; added locks and shutdown mechanisms; updated message handling for multi-threaded routing
  • transfer_queue/interface.py — Extracts num_worker_threads from the config and passes it to SimpleStorageUnit initialization
  • transfer_queue/config.yaml — Added the num_worker_threads configuration option with a default value of 4
Comments suppressed due to low confidence (2)

transfer_queue/storage/simple_backend.py:106

  • The get_data method lacks thread safety protection. While it reads from self.field_data, concurrent writes from put_data (line 131) or modifications from clear could cause race conditions, leading to inconsistent reads or crashes. The lock should be acquired before accessing self.field_data to ensure thread-safe reads in the multi-threaded environment.
    def get_data(self, fields: list[str], local_indexes: list[int]) -> dict[str, list]:
        """
        Get data from storage unit according to given fields and local_indexes.

        Args:
            fields: Field names used for getting data.
            local_indexes: Local indexes used for getting data.

        Returns:
            dict with field names as keys, corresponding data list as values.
        """
        result: dict[str, list] = {}

        for field in fields:
            # Validate field name
            if field not in self.field_data:
                raise ValueError(
                    f"StorageUnitData get_data operation receive invalid field: {field} beyond {self.field_data.keys()}"
                )

            if len(local_indexes) == 1:
                gathered_item = self.field_data[field][local_indexes[0]]
                result[field] = [gathered_item]

            else:
                gathered_items = list(itemgetter(*local_indexes)(self.field_data[field]))

                result[field] = gathered_items

        return result

transfer_queue/storage/simple_backend.py:131

  • The double-checked locking pattern used here is not thread-safe in Python. Between the outer check (line 119) and acquiring the lock (line 120), another thread could initialize the field, but more critically, the writes to list elements at line 131 are not protected by any lock. This creates race conditions where concurrent threads can simultaneously write to the same index. The entire write operation after field initialization should be protected by the lock, or use a different synchronization approach.
        for f, values in field_data.items():
            # Double-checked locking for field initialization
            if f not in self.field_data:
                with self._lock:
                    if f not in self.field_data:
                        self.field_data[f] = [None] * self.storage_size

            for i, idx in enumerate(local_indexes):
                if idx < 0 or idx >= self.storage_size:
                    raise ValueError(
                        f"StorageUnitData put_data operation receive invalid local_index: {idx} beyond "
                        f"storage_size: {self.storage_size}"
                    )

                self.field_data[f][idx] = values[i]
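One way to address this, sketched below with class and field names assumed from the snippet above, is to hold the lock across the entire write path rather than only around field initialization. With disjoint index sets the final state is deterministic:

```python
import threading

class StorageUnitData:
    """Minimal sketch (assumed names): hold the lock across the whole write."""

    def __init__(self, storage_size):
        self.storage_size = storage_size
        self.field_data = {}
        self._lock = threading.Lock()

    def put_data(self, field_data, local_indexes):
        with self._lock:  # covers both field initialization and element writes
            for f, values in field_data.items():
                if f not in self.field_data:
                    self.field_data[f] = [None] * self.storage_size
                for i, idx in enumerate(local_indexes):
                    if not 0 <= idx < self.storage_size:
                        raise ValueError(f"invalid local_index: {idx}")
                    self.field_data[f][idx] = values[i]

store = StorageUnitData(storage_size=8)
# Four threads write value t to indexes t and t+4 concurrently.
threads = [
    threading.Thread(target=store.put_data, args=({"x": [t, t]}, [t, t + 4]))
    for t in range(4)
]
for th in threads:
    th.start()
for th in threads:
    th.join()
print(store.field_data["x"])  # [0, 1, 2, 3, 0, 1, 2, 3]
```

The coarse lock serializes writers, which is the trade-off the review comment accepts in exchange for correctness; finer-grained per-field locks would be the next refinement.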


Comment on lines 247 to 259
 def _start_process_put_get(self) -> None:
-    """Create a daemon thread and start put/get process."""
-    self.process_put_get_thread = Thread(
-        target=self._process_put_get, name=f"StorageUnitProcessPutGetThread-{self.storage_unit_id}", daemon=True
+    """Start worker threads and ZMQ proxy for handling requests."""
+
+    # Start worker threads
+    for i in range(self.num_workers):
+        worker_thread = Thread(
+            target=self._worker_routine,
+            args=(i,),
+            name=f"StorageUnitWorkerThread-{self.storage_unit_id}-{i}",
+            daemon=True,
+        )
+        worker_thread.start()
+        self.worker_threads.append(worker_thread)

Copilot AI Feb 26, 2026


The multi-threaded worker pool is a significant architectural change that introduces concurrency concerns, but there are no tests that explicitly verify thread-safety or concurrent access patterns. The existing test test_put_get_multiple_clients tests multiple clients sequentially, not truly concurrent operations. Consider adding tests that use threading or asyncio to send simultaneous PUT/GET/CLEAR operations to verify thread-safety, race condition handling, and that the worker pool correctly processes concurrent requests without data corruption.

Comment on lines 206 to 213
self._finalizer = weakref.finalize(
    self,
    self._shutdown_resources,
    self._shutdown_event,
    self.worker_threads,
    self.proxy_thread,
    self.zmq_context,
)

Copilot AI Feb 26, 2026


The finalizer is registered before the threads and zmq_context are fully initialized. At this point (line 206), self.worker_threads is an empty list, self.proxy_thread is None, and self.zmq_context is None. These values are captured by the finalizer at registration time, not at cleanup time. When garbage collection occurs, the finalizer will attempt to shut down the wrong (empty/None) references instead of the actual running threads and context. The finalizer should be registered after _init_zmq_socket() and _start_process_put_get() complete, or it should pass self and access attributes dynamically.
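The pitfall can be reproduced in a few lines (the class and argument names here are simplified stand-ins, not the real code): weakref.finalize captures its callback arguments at registration time, so attributes rebound afterwards are never seen by the finalizer, while in-place mutations of a captured container are:

```python
import weakref

captured = []

def cleanup(worker_threads, proxy_thread):
    # Sees exactly the objects passed at registration time.
    captured.append((list(worker_threads), proxy_thread))

class Unit:
    def __init__(self):
        self.worker_threads = []
        self.proxy_thread = None
        # Registered too early: proxy_thread is captured as None, by value.
        self._finalizer = weakref.finalize(
            self, cleanup, self.worker_threads, self.proxy_thread
        )
        self.worker_threads.append("worker-0")  # in-place mutation IS visible
        self.proxy_thread = "proxy-thread"      # rebinding is NOT visible

u = Unit()
u._finalizer()  # invoke eagerly for demonstration
print(captured[-1])  # (['worker-0'], None) -- the real proxy thread is missed
```

Registering the finalizer only after initialization completes, as the comment suggests, makes the captured references point at the live objects. Note that per the weakref docs, the finalizer's args must not hold a reference back to the object itself, or it will never be collected.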


Comment on lines 119 to 122

 if f not in self.field_data:
-    self.field_data[f] = [None] * self.storage_size
+    with self._lock:
+        if f not in self.field_data:
+            self.field_data[f] = [None] * self.storage_size
Contributor:

Why do we use double-checked locking here? To avoid the overhead of acquiring the lock?

Collaborator (Author):

Multiple threads may pass the outer if check and try to acquire the lock; the inner re-check ensures the field is initialized only once.
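The pattern can be seen in a standalone sketch (the module-level names are illustrative): the outer check is the lock-free fast path once a field exists, and the inner re-check handles the threads that raced past the outer check before initialization finished:

```python
import threading

field_data = {}
_lock = threading.Lock()
lock_acquisitions = 0

def ensure_field(name, size):
    global lock_acquisitions
    if name not in field_data:          # fast path: no lock once initialized
        with _lock:
            lock_acquisitions += 1      # safe: incremented under the lock
            if name not in field_data:  # re-check: another thread may have won
                field_data[name] = [None] * size

threads = [threading.Thread(target=ensure_field, args=("obs", 4)) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

ensure_field("obs", 4)  # later calls never touch the lock at all
print(len(field_data["obs"]), "lock acquisitions:", lock_acquisitions)
```

Between 1 and 8 threads acquire the lock depending on timing, but the field is initialized exactly once, and every call after initialization skips the lock entirely. (As the earlier review comments note, this only protects initialization; element writes still need their own synchronization.)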


Copilot AI (Contributor) left a comment

Pull request overview

Copilot reviewed 5 out of 5 changed files in this pull request and generated 8 comments.

Comments suppressed due to low confidence (2)

transfer_queue/storage/simple_backend.py:131

  • The double-checked locking pattern for field initialization is not sufficient to prevent race conditions on the actual data writes at line 131. After the field is initialized (lines 118-122), multiple threads can concurrently write to self.field_data[f][idx] without synchronization. This can lead to:
  1. Lost updates: If two threads write to the same index simultaneously, one write may be lost.
  2. Data corruption: Depending on Python's GIL behavior and the data types involved (especially with complex objects like tensors), concurrent writes could corrupt data.

While the GIL provides some protection for simple operations, it's not guaranteed for all scenarios, especially with C extensions (like PyTorch tensors). Consider protecting the entire write operation (lines 124-131) with a lock, or using finer-grained locks per field or index range.

            for i, idx in enumerate(local_indexes):
                if idx < 0 or idx >= self.storage_size:
                    raise ValueError(
                        f"StorageUnitData put_data operation receive invalid local_index: {idx} beyond "
                        f"storage_size: {self.storage_size}"
                    )

                self.field_data[f][idx] = values[i]

transfer_queue/storage/simple_backend.py:106

  • The get_data method is not thread-safe in a multi-threaded environment. While put_data and clear use locks for write operations, get_data performs reads without any synchronization. This can lead to race conditions where:
  1. A worker thread reads from self.field_data[field] while another thread is modifying the dictionary structure in put_data (field initialization) or clear.
  2. The check if field not in self.field_data at line 92 followed by access at line 98/102 is not atomic, so a concurrent clear operation could modify the dictionary between the check and access.
  3. The itemgetter operation at line 102 could read partially updated data if another thread is writing to the same indexes.

Consider acquiring a read lock or using a readers-writer lock pattern to protect read operations, especially around dictionary access and list indexing.

    def get_data(self, fields: list[str], local_indexes: list[int]) -> dict[str, list]:
        """
        Get data from storage unit according to given fields and local_indexes.

        Args:
            fields: Field names used for getting data.
            local_indexes: Local indexes used for getting data.

        Returns:
            dict with field names as keys, corresponding data list as values.
        """
        result: dict[str, list] = {}

        for field in fields:
            # Validate field name
            if field not in self.field_data:
                raise ValueError(
                    f"StorageUnitData get_data operation receive invalid field: {field} beyond {self.field_data.keys()}"
                )

            if len(local_indexes) == 1:
                gathered_item = self.field_data[field][local_indexes[0]]
                result[field] = [gathered_item]

            else:
                gathered_items = list(itemgetter(*local_indexes)(self.field_data[field]))

                result[field] = gathered_items

        return result
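A readers-writer lock as suggested here is not in the Python standard library, but a minimal sketch (assumed names, writer starvation ignored) looks like this:

```python
import threading

class RWLock:
    """Minimal readers-writer lock sketch: many readers OR one writer."""

    def __init__(self):
        self._readers = 0
        self._readers_lock = threading.Lock()  # guards the reader count
        self._writer_lock = threading.Lock()   # held while readers or a writer are active

    def acquire_read(self):
        with self._readers_lock:
            self._readers += 1
            if self._readers == 1:
                self._writer_lock.acquire()  # first reader blocks writers

    def release_read(self):
        with self._readers_lock:
            self._readers -= 1
            if self._readers == 0:
                self._writer_lock.release()  # last reader lets writers in

    def acquire_write(self):
        self._writer_lock.acquire()

    def release_write(self):
        self._writer_lock.release()

rw = RWLock()
data = {"field": [0, 1, 2]}

rw.acquire_read()
snapshot = list(data["field"])  # many threads may read concurrently
rw.release_read()

rw.acquire_write()
data["field"][0] = 99           # exclusive access for the write
rw.release_write()
print(snapshot, data["field"])  # [0, 1, 2] [99, 1, 2]
```

Whether this beats a single lock depends on the read/write ratio; since the GIL already serializes bytecode, the win is mainly for long reads such as the itemgetter gather over large index lists.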


 data_fields = []
 for fname, col_idx in partition.field_name_mapping.items():
-    if col_mask[col_idx]:
+    if col_idx < len(col_mask) and col_mask[col_idx]:
Copilot AI Feb 26, 2026

This change adds a bounds check to prevent an IndexError when col_idx is out of range for col_mask. While this is a good defensive fix, it appears to be unrelated to the multi-threading changes in this PR.

Consider:

  1. Moving this fix to a separate PR for easier tracking and review.
  2. Adding a comment explaining under what conditions col_idx might exceed len(col_mask), as this could indicate a data consistency issue elsewhere in the code.
  3. Adding a warning log when this condition is detected to help identify the root cause.

Comment on lines +233 to +235
# Backend: DEALER for worker communication (connected via zmq.proxy)
self.worker_socket = create_zmq_socket(self.zmq_context, zmq.DEALER)
self.worker_socket.bind(self._inproc_addr)
Copilot AI Feb 26, 2026

The worker socket binding happens in _init_zmq_socket (line 235), but worker threads connect to this address in _worker_routine (line 283). There's a potential race condition where worker threads might try to connect before the backend socket is fully bound and ready.

While there's a retry mechanism in the bind operation (lines 225-231), adding a small delay or verification after the bind at line 235 would ensure the socket is ready before starting worker threads. Alternatively, consider binding the backend socket before starting any threads, or add connection retry logic in the worker threads.


@0oshowero0 changed the title from "[perf] Add multi-thread request processing for SimpleStorageUnit" to "[perf] Add zmq.proxy to accelerate request processing for SimpleStorageUnit" on Feb 27, 2026

Comment on lines 460 to 465
if proxy_thread and proxy_thread.is_alive():
    proxy_thread.join(timeout=5)

# Terminate ZMQ context to unblock proxy and workers
if zmq_context:
    zmq_context.term()
Contributor:

proxy_thread does not check shutdown_event proactively. Maybe we should close zmq_context first, since proxy_thread depends on it.
Btw, is it necessary for us to explicitly close these threads to ensure complete deallocation?

Collaborator (Author):

Indeed, it's not strictly necessary, but I believe the explicit destructor helps keep the log clearer when we press Ctrl+C.

Collaborator (Author):

Another point is that the underlying C++ code can use these explicit termination signals to exit properly, rather than blocking the terminal.

Contributor:

Yes, I am concerned that merely using the shutdown_event flag to exit the threads may not be thorough enough.



@0oshowero0 0oshowero0 merged commit ba5710e into Ascend:main Feb 28, 2026
5 checks passed
